{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Load-Network\" data-toc-modified-id=\"Load-Network-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Load Network</a></span></li><li><span><a href=\"#Explore-Directions\" data-toc-modified-id=\"Explore-Directions-2\"><span class=\"toc-item-num\">2&nbsp;&nbsp;</span>Explore Directions</a></span><ul class=\"toc-item\"><li><span><a href=\"#Interactive\" data-toc-modified-id=\"Interactive-2.1\"><span class=\"toc-item-num\">2.1&nbsp;&nbsp;</span>Interactive</a></span></li></ul></li></ul></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_stylegan_v1 = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ffmpeg installation location, for creating videos\n",
    "plt.rcParams['animation.ffmpeg_path'] = str('/usr/bin/ffmpeg')\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interact, interact_manual\n",
    "from IPython.display import display\n",
    "from ipywidgets import Button\n",
    "\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# StyleGAN2 Repo\n",
    "sys.path.append('/tf/notebooks/stylegan2')\n",
    "\n",
    "# StyleGAN Utils\n",
    "from stylegan_utils import load_network, gen_image_fun, synth_image_fun, create_video\n",
    "# v1 override\n",
    "if is_stylegan_v1:\n",
    "    from stylegan_utils import load_network_v1 as load_network\n",
    "    from stylegan_utils import gen_image_fun_v1 as gen_image_fun\n",
    "    from stylegan_utils import synth_image_fun_v1 as synth_image_fun\n",
    "\n",
    "import run_projector\n",
    "import projector\n",
    "import training.dataset\n",
    "import training.misc\n",
    "\n",
    "# Data Science Utils\n",
    "sys.path.append(os.path.join(*[os.pardir]*3, 'data-science-learning'))\n",
    "\n",
    "from ds_utils import generative_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dir = Path.home() / 'Documents/generated_data/stylegan'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_DIR = Path.home() / 'Documents/models/stylegan2'\n",
    "MODEL_NAME = 'original_ffhq'\n",
    "SNAPSHOT_NAME = 'stylegan2-ffhq-config-f'\n",
    "\n",
    "Gs, Gs_kwargs, noise_vars = load_network(str(MODELS_DIR / MODEL_NAME / SNAPSHOT_NAME) + '.pkl')\n",
    "\n",
    "Z_SIZE = Gs.input_shape[1]\n",
    "IMG_SIZE = Gs.output_shape[2:]\n",
    "IMG_SIZE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = gen_image_fun(Gs, np.random.randn(1, Z_SIZE), Gs_kwargs, noise_vars)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explore Directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_direction_grid(dlatent, direction, coeffs):\n",
    "    \"\"\"Plot a row of synthesized images, shifting `dlatent` along `direction`.\n",
    "\n",
    "    One subplot is drawn per coefficient; each image is synthesized from\n",
    "    `dlatent + coeff * direction` with `randomize_noise=False`.\n",
    "\n",
    "    Args:\n",
    "        dlatent: starting dlatent array handed (after shifting) to `synth_image_fun`.\n",
    "        direction: array added to `dlatent`, scaled by each coefficient.\n",
    "        coeffs: sized sequence of scalar multipliers, one subplot per entry.\n",
    "    \"\"\"\n",
    "    # nothing to draw; also avoids asking matplotlib for a zero-column grid\n",
    "    if len(coeffs) == 0:\n",
    "        return\n",
    "\n",
    "    # squeeze=False always yields a 2D array of Axes, so a single coefficient\n",
    "    # no longer produces a bare Axes object that cannot be indexed\n",
    "    fig, axes = plt.subplots(1, len(coeffs), figsize=(15, 10), dpi=100, squeeze=False)\n",
    "\n",
    "    for ax, coeff in zip(axes[0], coeffs):\n",
    "        # `+` builds a new array, so the input dlatent is left untouched\n",
    "        new_latent = dlatent + coeff * direction\n",
    "        ax.imshow(synth_image_fun(Gs, new_latent, Gs_kwargs, randomize_noise=False))\n",
    "        ax.set_title(f'Coeff: {coeff:0.1f}')\n",
    "        ax.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load learned direction\n",
    "direction = np.load('/tf/media/datasets/stylegan/learned_directions.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_latents = 5\n",
    "\n",
    "# Generate dlatents from the mapping network. The input latents are drawn with\n",
    "# randn (standard normal), the same sampler used for the gen_image_fun calls in\n",
    "# this notebook; rand would only give non-negative uniform values in [0, 1).\n",
    "dlatents = Gs.components.mapping.run(np.random.randn(nb_latents, Z_SIZE), None, truncation_psi=1.)\n",
    "\n",
    "# one row of images per sampled dlatent, sweeping the loaded direction\n",
    "for i in range(nb_latents):\n",
    "    plot_direction_grid(dlatents[i:i+1], direction, np.linspace(-2, 2, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the figure and image artist reused by the interactive cell\n",
    "dpi = 100\n",
    "fig, ax = plt.subplots(figsize=(7, 7), dpi=dpi)\n",
    "# stretch the axes over the whole canvas so the image has no margins\n",
    "fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)\n",
    "ax.axis('off')\n",
    "im = ax.imshow(\n",
    "    gen_image_fun(Gs, np.random.randn(1, Z_SIZE), Gs_kwargs, noise_vars, truncation_psi=1)\n",
    ")\n",
    "\n",
    "# close the figure so this cell renders nothing on its own\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch attribute names (sorted, because glob order is filesystem-dependent)\n",
    "directions_dir = MODELS_DIR / MODEL_NAME / 'directions' / 'set01'\n",
    "attributes = sorted(e.stem for e in directions_dir.glob('*.npy'))\n",
    "\n",
    "# get names of projected images: every sub-directory except 'tfrecords'.\n",
    "# Filtering here instead of calling list.remove avoids a ValueError when\n",
    "# no 'tfrecords' directory is present.\n",
    "data_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / '20200215_192547'\n",
    "entries = sorted(p.name for p in data_dir.glob(\"*\") if p.is_dir() and p.name != 'tfrecords')\n",
    "\n",
    "# set target latent to play with; randn (standard normal) is the same sampler\n",
    "# used for the gen_image_fun calls in this notebook\n",
    "dlatents = Gs.components.mapping.run(np.random.randn(1, Z_SIZE), None, truncation_psi=0.5)\n",
    "target_latent = dlatents[0:1]\n",
    "#target_latent = np.array([np.load(\"/out_4/image_latents2000.npy\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "@interact\n",
    "def i_direction(attribute=attributes, entry=entries, coeff=(-10., 10.)):\n",
    "    \"\"\"Redraw the shared figure for one projected latent moved along one direction.\n",
    "\n",
    "    Args:\n",
    "        attribute: stem of the direction file to load from `directions_dir`.\n",
    "        entry: sub-directory of `data_dir` holding the projected latent file.\n",
    "        coeff: scalar multiplier applied to the loaded direction.\n",
    "    \"\"\"\n",
    "    # both arrays are read from disk on every widget update\n",
    "    attr_direction = np.load(directions_dir / f'{attribute}.npy')\n",
    "    projected = np.load(data_dir / entry / \"image_latents1000.npy\")\n",
    "    # wrap in a list to add a leading axis of size one\n",
    "    base_latent = np.array([projected])\n",
    "\n",
    "    shifted_latent = base_latent + coeff * attr_direction\n",
    "    rendered = synth_image_fun(Gs, shifted_latent, Gs_kwargs, True)\n",
    "\n",
    "    # update the existing image artist in place, then re-show the closed figure\n",
    "    im.set_data(rendered)\n",
    "    ax.set_title('Coeff: %0.1f' % coeff)\n",
    "    display(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "StyleGAN",
   "language": "python",
   "name": "stylegan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
